{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Computing optimal policy using SARSA\n",
    "\n",
    "Now, let's implement SARSA to find the optimal policy in the frozen lake environment.\n",
    "\n",
    "First, let's import the necessary libraries:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import pandas as pd\n",
    "import random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we create the frozen lake environment using gym:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = gym.make('FrozenLake-v0')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's define the dictionary for storing the Q value of the state-action pair and we initialize\n",
    "the Q value of all the state-action pair to 0.0:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "Q = {}\n",
    "for s in range(env.observation_space.n):\n",
    "    for a in range(env.action_space.n):\n",
    "        Q[(s,a)] = 0.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's define the epsilon-greedy policy. We generate a random number from the\n",
    "uniform distribution and if the random number is less than epsilon we select the random\n",
    "action else we select the best action which has the maximum Q value: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def epsilon_greedy(state, epsilon):\n",
    "    if random.uniform(0,1) < epsilon:\n",
    "        return env.action_space.sample()\n",
    "    else:\n",
    "        return max(list(range(env.action_space.n)), key = lambda x: Q[(state,x)])"
   ]
  },
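  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see how epsilon controls this policy, here is a minimal, self-contained sketch. The helper\n",
    "`epsilon_greedy_action`, the table `demo_Q`, and its values are hypothetical and used only for\n",
    "illustration. Unlike `epsilon_greedy` above, the helper takes the Q table and the number of\n",
    "actions as arguments, so it runs without the environment:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "def epsilon_greedy_action(q_values, state, n_actions, epsilon, rng=random):\n",
    "    # explore: with probability epsilon, pick a uniformly random action\n",
    "    if rng.uniform(0, 1) < epsilon:\n",
    "        return rng.randrange(n_actions)\n",
    "    # exploit: otherwise, pick the action with the highest Q value\n",
    "    return max(range(n_actions), key=lambda a: q_values[(state, a)])\n",
    "\n",
    "# hypothetical Q values for a single state (state 0) with four actions\n",
    "demo_Q = {(0, 0): 0.1, (0, 1): 0.5, (0, 2): 0.2, (0, 3): 0.0}\n",
    "\n",
    "# seeded generator so the demonstration is reproducible\n",
    "rng = random.Random(0)\n",
    "\n",
    "# epsilon = 0.0: the policy never explores, so only the greedy action 1 is chosen\n",
    "greedy_picks = {epsilon_greedy_action(demo_Q, 0, 4, 0.0, rng) for _ in range(100)}\n",
    "\n",
    "# epsilon = 1.0: the policy always explores, ignoring the Q values\n",
    "random_picks = {epsilon_greedy_action(demo_Q, 0, 4, 1.0, rng) for _ in range(1000)}\n",
    "\n",
    "print(greedy_picks)\n",
    "print(sorted(random_picks))"
   ]
  },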
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Initialize the discount factor $\\gamma$ and the learning rate $\\alpha$ and epsilon value:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.85\n",
    "gamma = 0.90\n",
    "epsilon = 0.8"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Define the number of episodes and number of time steps in the episode:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_episodes = 5000\n",
    "num_timesteps = 1000"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the optimal policy using the SARSA update rule as:\n",
    "\n",
    "$$ Q(s,a) = Q(s,a) + \\alpha (r + \\gamma Q(s',a') - Q(s,a)) $$"
   ]
  },
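  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked illustration of this rule, consider two hypothetical transitions (the rewards and\n",
    "Q values below are assumed for the sketch, not taken from a run), with $\\alpha = 0.85$ and\n",
    "$\\gamma = 0.90$ as set above. Suppose the agent steps into the goal from the pair $(s,a)$, so\n",
    "that $r = 1$, while $Q(s,a)$ and $Q(s',a')$ are both still 0:\n",
    "\n",
    "$$ Q(s,a) = 0 + 0.85\\,(1 + 0.90 \\times 0 - 0) = 0.85 $$\n",
    "\n",
    "In a later episode, suppose a preceding pair $(\\bar{s},\\bar{a})$ with $Q(\\bar{s},\\bar{a}) = 0$\n",
    "leads, with $r = 0$, to that pair, whose value is now 0.85:\n",
    "\n",
    "$$ Q(\\bar{s},\\bar{a}) = 0 + 0.85\\,(0 + 0.90 \\times 0.85 - 0) = 0.65025 $$\n",
    "\n",
    "This is how the reward at the goal propagates backward, one state-action pair at a time."
   ]
  },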
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "#for each episode\n",
    "for i in range(20000):\n",
    "       \n",
    "    #initialize the state by resetting the environment\n",
    "    s = env.reset()\n",
    "    \n",
    "    #select the action using the epsilon-greedy policy\n",
    "    a = epsilon_greedy(s,epsilon)\n",
    "    \n",
    "    #for each step in the episode:\n",
    "    for t in range(num_timesteps):\n",
    "\n",
    "        #perform the selected action and store the next state information: \n",
    "        s_, r, done, _ = env.step(a)\n",
    "        \n",
    "        #select the action a dash in the next state using the epsilon greedy policy:\n",
    "        a_ = epsilon_greedy(s_,epsilon) \n",
    "        \n",
    "        #compute the Q value of the state-action pair\n",
    "        Q[(s,a)] += alpha * (r + gamma * Q[(s_,a_)]-Q[(s,a)])\n",
    "        \n",
    "        #update next state to current state\n",
    "        s = s_\n",
    "        \n",
    "        #update next action to current action\n",
    "        a = a_\n",
    "\n",
    "\n",
    "        #if the current state is the terminal state then break:\n",
    "        if done:\n",
    "            break\n",
    "     "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that on every iteration we update the Q function. After all the iterations, we will have\n",
    "the optimal Q function. Once we have the optimal Q function then we can extract the\n",
    "optimal policy by selecting the action which has maximum Q value in each state. "
   ]
  }
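  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The extraction step can be sketched as follows. This is a minimal illustration, not a definitive\n",
    "implementation: the helper `extract_greedy_policy` and the table `toy_Q` are hypothetical names\n",
    "introduced here, and the toy values are made up. Ties are resolved in favor of the lowest\n",
    "action index, because `max` returns the first maximal element:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_greedy_policy(q_values, n_states, n_actions):\n",
    "    # for each state, pick the action with the maximum Q value\n",
    "    return {\n",
    "        s: max(range(n_actions), key=lambda a: q_values[(s, a)])\n",
    "        for s in range(n_states)\n",
    "    }\n",
    "\n",
    "# hypothetical Q table with 2 states and 2 actions, used only for illustration\n",
    "toy_Q = {(0, 0): 0.2, (0, 1): 0.7, (1, 0): 0.4, (1, 1): 0.1}\n",
    "\n",
    "toy_policy = extract_greedy_policy(toy_Q, n_states=2, n_actions=2)\n",
    "print(toy_policy)  # {0: 1, 1: 0}\n",
    "\n",
    "# with the Q dictionary learned in this notebook, the call would be:\n",
    "# policy = extract_greedy_policy(Q, env.observation_space.n, env.action_space.n)"
   ]
  }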
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
